// -*- mode:c++ -*-

// Copyright (c) 2010 ARM Limited
// All rights reserved
//
// The license below extends only to copyright in the software and shall
// not be construed as granting a license to any other intellectual
// property including but not limited to intellectual property relating
// to a hardware implementation of the functionality of the software
// licensed hereunder.  You may use the software subject to the license
// terms below provided that you ensure that this notice is replicated
// unmodified and in its entirety in all distributions of the software,
// modified or unmodified, in source code or in binary form.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met: redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer;
// redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution;
// neither the name of the copyright holders nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

let {{

    header_output = ""
    decoder_output = ""
    exec_output = ""

    calcQCode = '''
        CpsrQ = (resTemp & 1) << 27;
    '''

    calcCcCode = '''
        uint16_t _iz, _in;
        _in = (resTemp >> %(negBit)d) & 1;
        _iz = ((%(zType)s)resTemp == 0);

        CondCodesNZ = (_in << 1) | _iz;

        DPRINTF(Arm, "(in, iz) = (%%d, %%d)\\n", _in, _iz);
       '''

    def buildMultInst(mnem, doCc, unCc, regs, code, flagType):
        '''Emit declaration, constructor and execute text for one multiply.

        mnem     -- mnemonic; the class name is mnem.capitalize().
        doCc     -- when true, emit the flag-setting form: mnemonic
                    mnem + "s", class name Name + "Cc", and `code` followed
                    by the flag snippet chosen from flagType.
        unCc     -- when true, emit the plain form built from `code` alone.
        regs     -- 3 or 4; picks the Mult3 or Mult4 base class together
                    with the matching Declare/Constructor templates.
        code     -- text of the instruction body.
        flagType -- "overflow" picks calcQCode; "llbit" makes calcCcCode
                    test bit 63 of a uint64_t; anything else makes it test
                    bit 31 of a uint32_t.

        Raises Exception if regs is neither 3 nor 4. The results are
        appended to header_output, decoder_output and exec_output, plain
        form first.
        '''
        global header_output, decoder_output, exec_output

        # NOTE(review): neither looked-up value is used below. The only
        # visible effect of these two subscripts is a KeyError when flagType
        # is missing from carryCode or overflowCode -- confirm whether that
        # implicit validation is intended before removing them.
        unusedCarry = carryCode[flagType]
        unusedOverflow = overflowCode[flagType]

        if regs not in (3, 4):
            raise Exception("Multiplication instructions with {} ".format(
                regs) + "registers are not implemented")

        # Choose the text appended to `code` for the flag-setting form.
        if flagType == "overflow":
            flagCode = calcQCode
        else:
            wide = (flagType == "llbit")
            flagCode = calcCcCode % {
                "negBit": 63 if wide else 31,
                "zType": "uint64_t" if wide else "uint32_t"
            }

        # Resolve everything that depends on the operand count in one place.
        if regs == 3:
            base = 'Mult3'
            declare = Mult3Declare
            constructor = Mult3Constructor
        else:
            base = 'Mult4'
            declare = Mult4Declare
            constructor = Mult4Constructor

        Name = mnem.capitalize()

        # Build the requested forms, plain one first, then emit them in that
        # same order.
        forms = []
        if unCc:
            forms.append(ArmInstObjParams(mnem, Name, base,
                    { "code" : code, "predicate_test": pickPredicate(code),
                      "op_class": "IntMultOp" }))
        if doCc:
            forms.append(ArmInstObjParams(mnem + "s", Name + "Cc", base,
                    { "code" : code + flagCode,
                      "predicate_test": pickPredicate(code + flagCode),
                      "op_class": "IntMultOp" }))

        for form in forms:
            header_output += declare.subst(form)
            decoder_output += constructor.subst(form)
            exec_output += PredOpExecute.subst(form)

    def buildMult3Inst(mnem, code, flagType = "logic"):
        '''Emit both the plain and the flag-setting form of a three-register
        multiply.'''
        buildMultInst(mnem, doCc=True, unCc=True, regs=3, code=code,
                      flagType=flagType)

    def buildMult3InstCc(mnem, code, flagType = "logic"):
        '''Emit only the flag-setting form of a three-register multiply.'''
        buildMultInst(mnem, doCc=True, unCc=False, regs=3, code=code,
                      flagType=flagType)

    def buildMult3InstUnCc(mnem, code, flagType = "logic"):
        '''Emit only the plain (non flag-setting) form of a three-register
        multiply.'''
        buildMultInst(mnem, doCc=False, unCc=True, regs=3, code=code,
                      flagType=flagType)

    def buildMult4Inst(mnem, code, flagType = "logic"):
        '''Emit both the plain and the flag-setting form of a four-register
        multiply.'''
        buildMultInst(mnem, doCc=True, unCc=True, regs=4, code=code,
                      flagType=flagType)

    def buildMult4InstCc(mnem, code, flagType = "logic"):
        '''Emit only the flag-setting form of a four-register multiply.'''
        buildMultInst(mnem, doCc=True, unCc=False, regs=4, code=code,
                      flagType=flagType)

    def buildMult4InstUnCc(mnem, code, flagType = "logic"):
        '''Emit only the plain (non flag-setting) form of a four-register
        multiply.'''
        buildMultInst(mnem, doCc=False, unCc=True, regs=4, code=code,
                      flagType=flagType)

    # Multiply-accumulate, multiply-subtract and multiply on the Reg
    # operands. mla and mul get a plain and a flag-setting form; mls gets
    # the plain form only. All three use the default flagType.
    buildMult4Inst(mnem="mla", code="Reg0 = resTemp = Reg1 * Reg2 + Reg3;")
    buildMult4InstUnCc(mnem="mls", code="Reg0 = Reg3 - Reg1 * Reg2;")
    buildMult3Inst(mnem="mul", code="Reg0 = resTemp = Reg1 * Reg2;")
    # Halfword multiply-accumulate forms whose flag snippet is calcQCode
    # ("overflow"): the second statement reduces resTemp to 1 when bits 32
    # and 31 of the sum differ, and calcQCode stores that bit in CpsrQ.
    #
    # The sum is formed in 64 bits. If it were formed at the operands' own
    # width, bit 32 could never differ from bit 31 and the test would be
    # meaningless. Only casts of the kind already used by smmla/smmul are
    # introduced.
    # NOTE(review): this assumes .sw is a 32-bit signed view and .sh0/.sh1
    # are 16-bit signed views -- confirm against the PInt operand
    # definition. If the views are already 64-bit the casts are no-ops.
    buildMult4InstCc("smlabb", '''
            PInt0 = resTemp = PInt1.sh0 * PInt2.sh0 + (int64_t)PInt3.sw;
            resTemp = bits(resTemp, 32) != bits(resTemp, 31);
        ''', "overflow")
    buildMult4InstCc("smlabt", '''
            PInt0 = resTemp = PInt1.sh0 * PInt2.sh1 + (int64_t)PInt3.sw;
            resTemp = bits(resTemp, 32) != bits(resTemp, 31);
        ''', "overflow")
    buildMult4InstCc("smlatb", '''
            PInt0 = resTemp = PInt1.sh1 * PInt2.sh0 + (int64_t)PInt3.sw;
            resTemp = bits(resTemp, 32) != bits(resTemp, 31);
        ''', "overflow")
    buildMult4InstCc("smlatt", '''
            PInt0 = resTemp = PInt1.sh1 * PInt2.sh1 + (int64_t)PInt3.sw;
            resTemp = bits(resTemp, 32) != bits(resTemp, 31);
        ''', "overflow")
    # Dual forms: the first product is widened so that both additions,
    # including the product + product one, are carried out in 64 bits.
    buildMult4InstCc("smlad", '''
            PInt0 = resTemp = (int64_t)(PInt1.sh1 * PInt2.sh1) +
                              PInt1.sh0 * PInt2.sh0 + PInt3.sw;
            resTemp = bits(resTemp, 32) != bits(resTemp, 31);
        ''', "overflow")
    buildMult4InstCc("smladx", '''
            PInt0 = resTemp = (int64_t)(PInt1.sh1 * PInt2.sh0) +
                              PInt1.sh0 * PInt2.sh1 + PInt3.sw;
            resTemp = bits(resTemp, 32) != bits(resTemp, 31);
        ''', "overflow")
    # Multiply-accumulate forms with a 64-bit accumulator: PInt1.uw supplies
    # the upper and PInt0.uw the lower 32 bits, and the 64-bit result is
    # written back as PInt0 (low word) and PInt1 (high word).
    #
    # smlal: the word product is widened before multiplying, matching the
    # widening already applied to the accumulator on the next line. With
    # the "llbit" flagType its flag-setting form tests bit 63 of resTemp,
    # so the full 64-bit sum is needed.
    # NOTE(review): this assumes .sw/.uw are 32-bit views and .sh0/.sh1 are
    # 16-bit views -- confirm against the PInt operand definition. If the
    # views are already 64-bit the added casts are no-ops.
    buildMult4Inst("smlal", '''
            resTemp = (int64_t)PInt2.sw * PInt3.sw +
                      (int64_t)(((uint64_t)PInt1.uw << 32) | PInt0.uw);
            PInt0 = (uint32_t)resTemp;
            PInt1 = (uint32_t)(resTemp >> 32);
        ''', "llbit")
    # Single halfword product added to the 64-bit accumulator; the addition
    # is already 64-bit because of the accumulator's cast.
    buildMult4InstUnCc("smlalbb", '''
            resTemp = PInt2.sh0 * PInt3.sh0 +
                      (int64_t)(((uint64_t)PInt1.uw << 32) | PInt0.uw);
            PInt0 = (uint32_t)resTemp;
            PInt1 = (uint32_t)(resTemp >> 32);
        ''')
    buildMult4InstUnCc("smlalbt", '''
            resTemp = PInt2.sh0 * PInt3.sh1 +
                      (int64_t)(((uint64_t)PInt1.uw << 32) | PInt0.uw);
            PInt0 = (uint32_t)resTemp;
            PInt1 = (uint32_t)(resTemp >> 32);
        ''')
    buildMult4InstUnCc("smlaltb", '''
            resTemp = PInt2.sh1 * PInt3.sh0 +
                      (int64_t)(((uint64_t)PInt1.uw << 32) | PInt0.uw);
            PInt0 = (uint32_t)resTemp;
            PInt1 = (uint32_t)(resTemp >> 32);
        ''')
    buildMult4InstUnCc("smlaltt", '''
            resTemp = PInt2.sh1 * PInt3.sh1 +
                      (int64_t)(((uint64_t)PInt1.uw << 32) | PInt0.uw);
            PInt0 = (uint32_t)resTemp;
            PInt1 = (uint32_t)(resTemp >> 32);
        ''')
    # Dual forms: the first product is widened so that product + product is
    # also a 64-bit addition instead of being summed before the widening.
    buildMult4InstUnCc("smlald", '''
            resTemp = (int64_t)(PInt2.sh1 * PInt3.sh1) +
                      PInt2.sh0 * PInt3.sh0 +
                      (int64_t)(((uint64_t)PInt1.uw << 32) | PInt0.uw);
            PInt0 = (uint32_t)resTemp;
            PInt1 = (uint32_t)(resTemp >> 32);
        ''')
    buildMult4InstUnCc("smlaldx", '''
            resTemp = (int64_t)(PInt2.sh1 * PInt3.sh0) +
                      PInt2.sh0 * PInt3.sh1 +
                      (int64_t)(((uint64_t)PInt1.uw << 32) | PInt0.uw);
            PInt0 = (uint32_t)resTemp;
            PInt1 = (uint32_t)(resTemp >> 32);
        ''')
    # Word-by-halfword multiply-accumulate with the Q flag snippet
    # ("overflow"). The accumulator is shifted up by 16 to line up with the
    # product, the sum is shifted back down by 16 and written to PInt0, and
    # resTemp is then reduced to 1 when bits 32 and 31 of that shifted sum
    # differ.
    #
    # Both the product and the shifted accumulator are widened to 64 bits
    # first: a word times a halfword, and a word shifted left by 16, each
    # need more than 32 bits.
    # NOTE(review): this assumes .sw is a 32-bit signed view and .sh0/.sh1
    # are 16-bit signed views -- confirm against the PInt operand
    # definition. If the views are already 64-bit the casts are no-ops.
    buildMult4InstCc("smlawb", '''
            resTemp = (int64_t)PInt1.sw * PInt2.sh0 +
                      ((int64_t)PInt3.sw << 16);
            PInt0 = resTemp = resTemp >> 16;
            resTemp = bits(resTemp, 32) != bits(resTemp, 31);
        ''', "overflow")
    buildMult4InstCc("smlawt", '''
            resTemp = (int64_t)PInt1.sw * PInt2.sh1 +
                      ((int64_t)PInt3.sw << 16);
            PInt0 = resTemp = resTemp >> 16;
            resTemp = bits(resTemp, 32) != bits(resTemp, 31);
        ''', "overflow")
    # Dual halfword multiply-subtract-accumulate with the Q flag snippet
    # ("overflow"): resTemp is reduced to 1 when bits 32 and 31 of the
    # result differ. The accumulator is widened so the final addition is a
    # 64-bit one; at the operands' own width bit 32 could never differ from
    # bit 31.
    # NOTE(review): this assumes .sw is a 32-bit signed view and .sh0/.sh1
    # are 16-bit signed views -- confirm against the PInt operand
    # definition. If the views are already 64-bit the cast is a no-op.
    buildMult4InstCc("smlsd", '''
            PInt0 = resTemp = PInt1.sh0 * PInt2.sh0 -
                              PInt1.sh1 * PInt2.sh1 + (int64_t)PInt3.sw;
            resTemp = bits(resTemp, 32) != bits(resTemp, 31);
        ''', "overflow")
    buildMult4InstCc("smlsdx", '''
            PInt0 = resTemp = PInt1.sh0 * PInt2.sh1 -
                              PInt1.sh1 * PInt2.sh0 + (int64_t)PInt3.sw;
            resTemp = bits(resTemp, 32) != bits(resTemp, 31);
        ''', "overflow")
    # 64-bit accumulator variants: PInt1.uw:PInt0.uw is the accumulator and
    # the result is written back as PInt0 (low word) and PInt1 (high word).
    # The accumulator is already widened here, so these are unchanged.
    buildMult4InstUnCc("smlsld", '''
            resTemp = PInt2.sh0 * PInt3.sh0 - PInt2.sh1 * PInt3.sh1 +
                      (int64_t)(((uint64_t)PInt1.uw << 32) | PInt0.uw);
            PInt0 = (uint32_t)resTemp;
            PInt1 = (uint32_t)(resTemp >> 32);
        ''')
    buildMult4InstUnCc("smlsldx", '''
            resTemp = PInt2.sh0 * PInt3.sh1 - PInt2.sh1 * PInt3.sh0 +
                      (int64_t)(((uint64_t)PInt1.uw << 32) | PInt0.uw);
            PInt0 = (uint32_t)resTemp;
            PInt1 = (uint32_t)(resTemp >> 32);
        ''')
    # "Most significant word" multiplies: PInt0 receives the value shifted
    # right by 32. The accumulating forms add or subtract the product
    # to/from PInt3.sw placed in the upper word; the forms ending in "r"
    # add 1 << 31 before the shift.
    #
    # The product is widened before multiplying in every form, the same way
    # smmul and smmulr below already do it. Without the cast the
    # multiplication is evaluated at the operands' own width before it is
    # combined with the 64-bit accumulator.
    # NOTE(review): this assumes .sw is a 32-bit signed view -- confirm
    # against the PInt operand definition. If it is already 64-bit the
    # casts are no-ops.
    buildMult4InstUnCc("smmla", '''
            PInt0 = (((int64_t)PInt3.sw << 32) +
                     ((int64_t)PInt1.sw * PInt2.sw)) >> 32;
        ''')
    buildMult4InstUnCc("smmlar", '''
            PInt0 = (((int64_t)PInt3.sw << 32) +
                     ((int64_t)PInt1.sw * PInt2.sw) +
                     (0x1ULL << 31)) >> 32;
        ''')
    buildMult4InstUnCc("smmls", '''
            PInt0 = (((int64_t)PInt3.sw << 32) -
                     ((int64_t)PInt1.sw * PInt2.sw)) >> 32;
        ''')
    buildMult4InstUnCc("smmlsr", '''
            PInt0 = (((int64_t)PInt3.sw << 32) -
                     ((int64_t)PInt1.sw * PInt2.sw) +
                     (0x1ULL << 31)) >> 32;
        ''')
    buildMult3InstUnCc("smmul", '''
            PInt0 = ((int64_t)PInt1.sw * PInt2.sw) >> 32;
        ''')
    buildMult3InstUnCc("smmulr", '''
            PInt0 = (((int64_t)PInt1.sw * PInt2.sw) + (0x1ULL << 31)) >> 32;
        ''')
    # NOTE(review): every cast added in this group assumes .sw is a 32-bit
    # signed view and .sh0/.sh1 are 16-bit signed views -- confirm against
    # the PInt operand definition. If the views are already 64-bit the
    # casts are no-ops.

    # Dual halfword multiply-add with the Q flag snippet ("overflow"):
    # resTemp is reduced to 1 when bits 32 and 31 of the sum differ. The
    # first product is widened so the sum is formed in 64 bits; at the
    # operands' own width bit 32 could never differ from bit 31.
    buildMult3InstCc("smuad", '''
            PInt0 = resTemp = (int64_t)(PInt1.sh0 * PInt2.sh0) +
                              PInt1.sh1 * PInt2.sh1;
            resTemp = bits(resTemp, 32) != bits(resTemp, 31);
        ''', "overflow")
    buildMult3InstCc("smuadx", '''
            PInt0 = resTemp = (int64_t)(PInt1.sh0 * PInt2.sh1) +
                              PInt1.sh1 * PInt2.sh0;
            resTemp = bits(resTemp, 32) != bits(resTemp, 31);
        ''', "overflow")
    # Single halfword-by-halfword products.
    buildMult3InstUnCc("smulbb", '''PInt0 = PInt1.sh0 * PInt2.sh0;''')
    buildMult3InstUnCc("smulbt", '''PInt0 = PInt1.sh0 * PInt2.sh1;''')
    buildMult3InstUnCc("smultb", '''PInt0 = PInt1.sh1 * PInt2.sh0;''')
    buildMult3InstUnCc("smultt", '''PInt0 = PInt1.sh1 * PInt2.sh1;''')
    # Word-by-word product split into PInt0 (low word) and PInt1 (high
    # word). The multiplication is widened first: the high word and, for
    # the flag-setting form ("llbit"), bit 63 of resTemp are both read.
    buildMult4Inst("smull", '''
            resTemp = (int64_t)PInt2.sw * PInt3.sw;
            PInt0 = (int32_t)resTemp;
            PInt1 = (int32_t)(resTemp >> 32);
        ''', "llbit")
    # Word-by-halfword product shifted right by 16; widened first because
    # the product needs more than 32 bits before the shift.
    buildMult3InstUnCc("smulwb", '''
            PInt0 = ((int64_t)PInt1.sw * PInt2.sh0) >> 16;
        ''')
    buildMult3InstUnCc("smulwt", '''
            PInt0 = ((int64_t)PInt1.sw * PInt2.sh1) >> 16;
        ''')
    # Difference of two halfword products.
    buildMult3InstUnCc("smusd", '''
            PInt0 = PInt1.sh0 * PInt2.sh0 - PInt1.sh1 * PInt2.sh1;
        ''')
    buildMult3InstUnCc("smusdx", '''
            PInt0 = PInt1.sh0 * PInt2.sh1 - PInt1.sh1 * PInt2.sh0;
        ''')
    # Unsigned multiplies producing a 64-bit result that is written back as
    # PInt0 (low word) and PInt1 (high word).
    #
    # The first factor is widened to uint64_t so the product and the
    # additions that follow are evaluated in 64 bits, and umlal widens
    # PInt1.uw before shifting it by 32, the same way smlal builds its
    # accumulator. Without the casts the high word is lost before
    # `resTemp >> 32` can read it.
    # NOTE(review): this assumes .uw is a 32-bit unsigned view -- confirm
    # against the PInt operand definition. If it is already 64-bit the
    # casts are no-ops.
    buildMult4InstUnCc("umaal", '''
            resTemp = (uint64_t)PInt2.uw * PInt3.uw + PInt0.uw + PInt1.uw;
            PInt0 = (uint32_t)resTemp;
            PInt1 = (uint32_t)(resTemp >> 32);
        ''')
    buildMult4Inst("umlal", '''
            resTemp = (uint64_t)PInt2.uw * PInt3.uw + PInt0.uw +
                      ((uint64_t)PInt1.uw << 32);
            PInt0 = (uint32_t)resTemp;
            PInt1 = (uint32_t)(resTemp >> 32);
        ''', "llbit")
    buildMult4Inst("umull", '''
            resTemp = (uint64_t)PInt2.uw * PInt3.uw;
            PInt0 = (uint32_t)resTemp;
            PInt1 = (uint32_t)(resTemp >> 32);
        ''', "llbit")
}};
